package cn.edu.hitsz.compiler.asm;

import cn.edu.hitsz.compiler.NotImplementedException;
import cn.edu.hitsz.compiler.ir.IRImmediate;
import cn.edu.hitsz.compiler.ir.IRValue;
import cn.edu.hitsz.compiler.ir.IRVariable;
import cn.edu.hitsz.compiler.ir.Instruction;
import cn.edu.hitsz.compiler.utils.FileUtils;

import java.util.*;


/**
 * TODO: 实验四: 实现汇编生成
 * <br>
 * 在编译器的整体框架中, 代码生成可以称作后端, 而前面的所有工作都可称为前端.
 * <br>
 * 在前端完成的所有工作中, 都是与目标平台无关的, 而后端的工作为将前端生成的目标平台无关信息
 * 根据目标平台生成汇编代码. 前后端的分离有利于实现编译器面向不同平台生成汇编代码. 由于前后
 * 端分离的原因, 有可能前端生成的中间代码并不符合目标平台的汇编代码特点. 具体到本项目你可以
 * 尝试加入一个方法将中间代码调整为更接近 risc-v 汇编的形式, 这样会有利于汇编代码的生成.
 * <br>
 * 为保证实现上的自由, 框架中并未对后端提供基建, 在具体实现时可自行设计相关数据结构.
 *
 * @see AssemblyGenerator#run() 代码生成与寄存器分配
 */
public class AssemblyGenerator {
    private List<Instruction> instuctions;
    private List<String> asm;

    /**寄存器描述符 地址描述符 空闲的寄存器 未使用的内存地址记录 每个变量的最后活跃时间 涉及的所有变量 未开辟的最低内存地址*/
    private Map<String, IRVariable> registerDescriptor;
    private Map<IRVariable, String> addressDescriptor;
    private List<String> freeReg;
    private List<String> freeAddr;
    private Map<IRVariable, Integer> expire;
    private Set<IRVariable> vars;
    private int offset;

    /** 取用的寄存器号 */
    private int regCount = 0;

    /**缩进 */
    private String tab = "    ";

    /**寄存器锁
     * 被锁住的寄存器号将不能被抢占
     * 例如ir (op, $1, $2, $3) 中
     * 按从左到右的顺序为三个变量分配寄存器
     * 后者不能抢占前者的寄存器 故在为右边的变量分配寄存器时需要对左边已确定的寄存器加锁
     * **/
    private int lock1 = -1;
    private int lock2 = -1;

    public AssemblyGenerator() {
        registerDescriptor = new HashMap<>();
        addressDescriptor = new HashMap<>();
        freeReg = new ArrayList<>(Arrays.asList("t0","t1","t2","t3","t4","t5","t6"));
        freeAddr = new ArrayList<>();
        offset = 0;
    }

    /**为寄存器加锁 */
    public void aquireLock(String reg){
        if(isReg(reg)){
            // 如果lock1已被占用则使用lock2
            if(lock1 >= 0){
                lock2 = reg.charAt(1) - '0';
            }
            else {
                lock1 = reg.charAt(1) - '0';
            }
        }
        else {
            throw new RuntimeException("Illegal register!");
        }
    }

    /**释放所有寄存器锁 */
    public void releaseLock(){
        lock1 = -1;
        lock2 = -1;
    }


    /** 判断一个字符串是否是寄存器的表示 */
    private boolean isReg(String memo){
        return memo.startsWith("t");
    }

    /**文法符号映射到常量或内存地址 */
    private String mapping(IRValue val){
        if(val.isImmediate()){
            return val.toString();
        }
        else {
            return getReg((IRVariable) val);
        }
    }

    /**从空闲寄存器中取出一个寄存器 */
    private String popReg(){

        if(freeReg.size() <= 0){
            throw new RuntimeException("register pressure!");
        }
        return freeReg.remove(0);
    }

    /**获取当前轮需要交换出去的寄存器 因为每个寄存器存的都是一个int 所以认为代价相同 使用计数轮询方式溢出 跳过加锁的寄存器 */
    private String getRegForSwap(){
        /**跳过被加锁的寄存器 */
        while (regCount == lock1 || regCount ==  lock2){
            ++regCount;
            regCount %= 7;
        }
        String res =  "t" + regCount;
        ++regCount;
        regCount %= 7;
        return res;
    }

    /** 为变量获取一个空闲寄存器 */
    private String getReg(IRVariable var){
        String reg = "";

        // 如果var已被保存在了寄存器过内存中
        if(addressDescriptor.containsKey(var)){
            String memo = addressDescriptor.get(var);
            // 如果已经存在寄存器里了 直接返回这个寄存器
            if(isReg(memo)){
                return memo;
            }
            //  如果存在了内存地址中
            else {
                // 如果当前有空闲寄存器 直接加载进空闲寄存器 并更新相应的描述符
                if(freeReg.size()>0){
                    reg = popReg();
                    asm.add(new AssemblyInstruction("lw",reg,memo).toString());
                    registerDescriptor.put(reg,var);
                    addressDescriptor.put(var,reg);
                    return reg;
                }
                // 否则将一个寄存器中存放的内容和内存中的地址交换 tx <-> memo
                else {
                    String tx = getRegForSwap();
                    String newMemo = offset + "(x0)";
                    asm.add(new AssemblyInstruction("sw",tx,newMemo).toString());
                    asm.add(new AssemblyInstruction("lw",tx,memo).toString());

                    // 获取t0寄存器原本保存的内容
                    IRVariable or = registerDescriptor.get(tx);

                    // 更新t0寄存器所保存的内容 更新or和var的最新保存地址
                    registerDescriptor.put(tx,var);
                    addressDescriptor.put(var,tx);
                    addressDescriptor.put(or,newMemo);
                    // 更新安全偏移值
                    offset += 4;
                    return tx;

                }
            }
        }

        // 未保存在寄存器中过则为其分配新的寄存器
        // 如果有空闲寄存器 取一个 更新描述符
        if(freeReg.size() > 0){
            reg = popReg();
            registerDescriptor.put(reg,var);
            addressDescriptor.put(var,reg);
        }
        // 没有空闲的寄存器 将tx寄存器中存放的数据溢出到内存，然后将tx分配给这个变量 同时修改描述符
        else {
            String tx = getRegForSwap();
            String memo = offset + "(x0)";
            asm.add(new AssemblyInstruction("sw",tx,memo).toString());
            // 获取tx寄存器原本保存的内容
            IRVariable or = registerDescriptor.get(tx);

            // 更新tx寄存器所保存的内容 更新or和var的最新保存地址
            registerDescriptor.put(tx,var);
            addressDescriptor.put(var,tx);
            addressDescriptor.put(or,memo);
            // 更新安全偏移值
            offset += 4;
            reg = tx;

        }
        return reg;
    }


    /**根据变量的最后活跃时间对寄存器进行重新分配 */
    private void realloc(int count){
        // 遍历每个变量
        for(IRVariable val : vars){
            // 如果变量过期了 将该变量原本占用的寄存器空闲出来
            if (count > expire.get(val)){
                String reg = addressDescriptor.remove(val);
                if(reg != null && isReg(reg)){
                    freeReg.add(reg);
                    registerDescriptor.remove(reg);
                }
            }
        }

    }

    /**
     * 加载前端提供的中间代码
     * <br>
     * 视具体实现而定, 在加载中或加载后会生成一些在代码生成中会用到的信息. 如变量的引用
     * 信息. 这些信息可以通过简单的映射维护, 或者自行增加记录信息的数据结构.
     *
     * @param originInstructions 前端提供的中间代码
     */
    public void loadIR(List<Instruction> originInstructions) {
        this.instuctions = originInstructions;
        expire = new HashMap<>();
        int count = 1;

        // 统计每个变量的失效时间
        for(Instruction inst : instuctions){
            switch (inst.getKind()){
                case RET -> {
                    IRValue retrunValue = inst.getReturnValue();
                    if(!retrunValue.isImmediate()){
                        expire.put((IRVariable) retrunValue,count);
                    }
                }
                case MOV -> {
                    IRVariable res = inst.getResult();
                    IRValue from = inst.getFrom();

                    expire.put(res,count);
                    if(!from.isImmediate()){
                        expire.put((IRVariable)from,count);
                    }
                }
                default -> {
                    IRVariable result = inst.getResult();
                    expire.put(result,count);

                    IRValue lhs = inst.getLHS();
                    IRValue rhs = inst.getRHS();
                    if(!lhs.isImmediate()){
                        expire.put((IRVariable)lhs,count);
                    }
                    if(!rhs.isImmediate()){
                        expire.put((IRVariable)rhs,count);
                    }
                }
            }
            ++count;
        }
        // 收集分析过程中会用到的所有变量
        this.vars = expire.keySet();
    }


    /**
     * 执行代码生成.
     * <br>
     * 根据理论课的做法, 在代码生成时同时完成寄存器分配的工作. 若你觉得这样的做法不好,
     * 也可以将寄存器分配和代码生成分开进行.
     * <br>
     * 提示: 寄存器分配中需要的信息较多, 关于全局的与代码生成过程无关的信息建议在代码生
     * 成前完成建立, 与代码生成的过程相关的信息可自行设计数据结构进行记录并动态维护.
     */
    public void run() {
        // 初始化汇编结果
        asm = new ArrayList<>();
        // 记录汇编生成的轮数（以一条中间代码为一轮）
        int count = 1;

        for(Instruction inst: instuctions){
            // 重分配在本轮及以后不会活跃的变量占用的寄存器
            realloc(count);

            // 根据IR生成汇编代码
            switch (inst.getKind()){
                case MOV -> {
                    IRVariable res = inst.getResult();
                    IRValue from = inst.getFrom();
                    String dst,rs1;

                    // 获取dst寄存器位置 加锁
                    dst = getReg(res);
                    aquireLock(dst);

                    // 获取rs1 如果from是一个变量则直接li
                    if(from.isImmediate()){
                        rs1 = from.toString();
                        asm.add(new AssemblyInstruction("li",dst,rs1).setAnnotation(inst.toString()).toString());
                    }
                    // 如果from是变量的话则为其分配寄存器并使用mv指令
                    else {
                        rs1 = getReg((IRVariable) from);
                        asm.add(new AssemblyInstruction("mv",dst,rs1).setAnnotation(inst.toString()).toString());
                    }
                    releaseLock();
                }

                case ADD -> {
                    IRVariable result = inst.getResult();
                    IRValue lhs = inst.getLHS();
                    IRValue rhs = inst.getRHS();
                    String dst,rs1,rs2;

                    // 获取dst、rs1、rs2寄存器位置 加锁
                    dst = getReg(result);
                    aquireLock(dst);
                    rs1 = mapping(lhs);
                    aquireLock(rs1);
                    rs2 = mapping(rhs);

                    // 生成汇编代码 释放锁
                    asm.add(new AssemblyInstruction("add",dst,rs1,rs2).setAnnotation(inst.toString()).toString());
                    releaseLock();
                }

                case MUL -> {
                    IRVariable result = inst.getResult();
                    IRValue lhs = inst.getLHS();
                    IRValue rhs = inst.getRHS();
                    String dst,rs1,rs2;

                    // 获取dst、rs1、rs2寄存器位置 加锁
                    dst = getReg(result);
                    aquireLock(dst);
                    rs1 = mapping(lhs);
                    aquireLock(rs1);
                    rs2 = mapping(rhs);

                    // 生成汇编代码 释放锁
                    asm.add(new AssemblyInstruction("mul",dst,rs1,rs2).setAnnotation(inst.toString()).toString());
                    releaseLock();
                }

                case SUB -> {
                    IRVariable result = inst.getResult();
                    IRValue lhs = inst.getLHS();
                    IRValue rhs = inst.getRHS();
                    String dst,rs1,rs2;

                    // 获取dst、rs1、rs2寄存器位置 加锁
                    dst = getReg(result);
                    aquireLock(dst);
                    rs1 = mapping(lhs);
                    aquireLock(rs1);
                    rs2 = mapping(rhs);

                    // 生成汇编代码 释放锁
                    asm.add(new AssemblyInstruction("sub",dst,rs1,rs2).setAnnotation(inst.toString()).toString());
                    releaseLock();
                }

                case RET -> {
                    IRValue retrunValue = inst.getReturnValue();

                    // 根据返回结果是变量还是立即数生成汇编代码
                    if(retrunValue.isImmediate()){
                        asm.add(new AssemblyInstruction("li","a0",retrunValue.toString()).setAnnotation(inst.toString()).toString());
                    }
                    else {
                        asm.add(new AssemblyInstruction("mv","a0",getReg((IRVariable) retrunValue)).setAnnotation(inst.toString()).toString());
                    }
                }

                default -> {

                }
            }
            // 轮数++
            ++count;
        }
    }


    /**
     * 输出汇编代码到文件
     *
     * @param path 输出文件路径
     */
    public void dump(String path) {
        String res = ".text\n";
        for(String s : asm){
            res += tab;
            res += s;
        }
        FileUtils.writeFile(path,res);

    }
}

